#! /usr/bin/python3
# -*- coding: utf-8 -*-


import numpy as np
import matplotlib.pyplot as plt



def Xprime(X, c, d, r):
    """ Renvoie X' """
    return np.array([
        -c*X[0]*X[3], #S'
        d*X[3], #D'
        r*X[3], #R'
        c*X[0]*X[3] -  (d+r)*X[3]#M'
    ])

def euler( M0, c, r, d, tf, t0=0, D0=0,R0=0, dt=1):
    
    nTau = int(tau/dt) #nb de pas de temps que vaut τ
    X=np.array([1-M0-D0-R0, D0, R0, M0])
    tX=[X]
    t=t0
    tt=[t]
    
    while t<tf:
        Xp= Xprime(X, c, d, r)
        #print(Xp)
        X=X+dt*Xp
        tX.append(X)
        tt.append(t)
        t+=dt

    return tt, tX


def dessin( M0, c, r, d, tf, t0=0, D0=0,R0=0, dt=1):
    tt, tX =euler( M0, c, r, d, tf, t0=t0, D0=D0,R0=R0, dt=dt)
    plt.plot(tt, [X[0] for X in tX], label="sains")
    plt.plot(tt, [X[1] for X in tX],label="décédés")
    plt.plot(tt, [X[2] for X in tX], label="rétablis")
    plt.plot(tt, [X[3] for X in tX], label="malades")
    
    plt.xlabel("Temps (jours)")
    plt.ylabel("Proportion de population")
    plt.legend()
    plt.show()
    print( f"{int(tX[-1][1]*70e6)} morts")




### Calcul des paramètres ###


tau = 10 # durée moyenne de la maladie
taux_mortalité=0.01
d_estimé = taux_mortalité/tau
r_estimé = 1/tau - d_estimé


chemin_confirmés="time_series_covid19_confirmed_global.csv"
chemin_rétablis="time_series_covid19_recovered_global.csv"
ligne_france=118
# La ligne 0 donne les noms des colonnes, en particulier les dates
#début au 22 janvier
avant_confinement = 9+29+10 # Nb de jours entre le début des mesures et le début du confinement en France.


def enlève_année(c):
    """Entrée : chaîne de caractère représentant une date au format "j/m/a", l'année sur deux chiffres.
     Sortie : la date sans l'année."""
    return c[:-3]


def extrait_donnée(début, fin, chemin=chemin_confirmés, num_ligne=ligne_france,  tout=False, debug=False):
    """
    Entrée : deux indices début et fin
    Sortie : deux tableaux : (dates, nb de malades officiellement recencés)
             On ne renvoie que les données qui correspondent à une date située entre début et fin après le 22 janvier.
    """
    entrée=open(chemin)
    if tout: début, fin = 0, -5
    premièreLigne=list(map(enlève_année,entrée.readline().split(",")[4+début:4+fin]))
    for _ in range(num_ligne-2):
        entrée.readline()
    ligne=entrée.readline().split(",")
    if debug: print(f"Lecture du fichier {chemin}, voici le début de la ligne : {ligne[:4]}")
    d=list(map(int, ligne[4+début:4+fin]))
    return premièreLigne, d



def dessins_données_passé(début, fin, chemin=chemin_confirmés, tout=False):
    premièreLigne, d = extrait_donnée(début, fin, chemin=chemin, tout=tout)
    plt.plot(premièreLigne, d, label="cas confirmés" )
    plt.xlabel("date")
    plt.ylabel("nombre de cas")
    plt.legend()
    plt.title("nombre de cas en fonction de la date")
    plt.show()

    plt.plot(premièreLigne, np.log(d), label="ln(cas confirmés)" )
    plt.title("ln(nombre de cas) en fonction de la date")
    plt.xlabel("date")
    plt.legend()
    plt.show()


# par lecture graphique, c_initial : du 26/02  au 15/03
#                        c_confinement : du 1/04 au 19/04
# début des mesures au 22 janvier
déb_pour_c_initial = 10 + 25
fin_pour_c_initial = 10 + 29 + 14
déb_pour_c_confinement = 10 + 29 + 31
fin_pour_c_confinement = 10 + 29 + 31 + 18


def calcule_lambda_et_ln_M0(début, fin, chemin=chemin_confirmés):
    premièreLigne, d = extrait_donnée(début, fin, chemin=chemin)
    return np.polyfit(range(len(d)), np.log(np.array(d)/7e7), 1)
lambda_exp, ln_M0= calcule_lambda_et_ln_M0(déb_pour_c_confinement, fin_pour_c_confinement  )
c_con = lambda_exp + d_estimé + r_estimé # la constante c en temps de confinement
M0=np.exp(ln_M0)
print(c_con,M0)

# bonus...
def calcule_lambda_et_M0_et_R0(début, fin, chemin=chemin_confirmés):
    """ Pour un départ plus avancé où R0 ne serait pas négligeable."""
    
    premièreLigne, d = extrait_donnée(début, fin, chemin=chemin)
    _, r = extrait_donnée(début,fin, num_ligne=110, chemin=chemin_rétablis, debug=False)
    
    a,b= np.polyfit(range(len(d)), np.log(np.array(d)/7e7), 1)
    return (a,b,r[0]/7e7)


lambda_exp, ln_M0, R0 = calcule_lambda_et_M0_et_R0(déb_pour_c_confinement, fin_pour_c_confinement  )
c_con = lambda_exp + d_estimé + r_estimé
M0=np.exp(ln_M0)

#print(calcule_lambda_et_M0_et_R0(déb_pour_c_confinement, fin_pour_c_confinement  ))

lambda_init, ln_Mini = calcule_lambda_et_ln_M0(déb_pour_c_initial, fin_pour_c_initial  )
c_ini = lambda_init + d_estimé + r_estimé
M_ini = np.exp(ln_Mini)






#### En prenant en compte capacité des hopitaux et fin du confinement
# Avouons-le : ça devient du doigt mouillé ici !

def modé2( M0, tf, Mc, tfc, t0=0, D0=0,R0=0, dt=1):

    def d(M):
        if M>Mc : return 2*d_estimé
        else :return d_estimé
    def c(t):
        if t>tfc:return c_ini
        else : return c_con
    
    X=np.array([1-M0-D0-R0, D0, R0, M0])
    tX=[X]
    t=t0
    tt=[t]
    
    while t<tf:
        Xp=Xprime(X, c(t), d(X[3]), r_estimé )
        #print(Xp)
        X=X+dt*Xp
        tX.append(X)
        tt.append(t)
        t+=dt

    return tt, tX


def dessin2( M0, tf, Mc, tfc, t0=0, D0=0, R0=0, dt=1):
    tt, tX =modé2(M0, tf, Mc, tfc, t0=t0, D0=D0,R0=R0, dt=dt)
    #plt.plot(tt, [X[0] for X in tX], label="sains")
    plt.plot(tt, [X[1] for X in tX],label="décédés")
    plt.plot(tt, [X[2] for X in tX], label="rétablis")
    plt.plot(tt, [X[3] for X in tX], label="malades")
    plt.plot( [tt[0],tt[-1]], [Mc,Mc], color="red",linestyle="--")
    
    plt.xlabel("Temps (jours)")
    plt.ylabel("Proportion de population")
    plt.legend()
    plt.show()
